from datetime import timedelta

# Substring of the model name -> (module path, agent class name).
# Checked in order; the first key contained in the model name wins, matching
# the original if/elif chain. Modules are imported lazily by init_agent so that
# only the selected backend's dependencies are loaded.
_AGENT_REGISTRY = (
    ("OS_Atlas", "gui_speaker.models.OS_ATLAS", "OS_ATLAS_Agent"),
    ("UI-TARS", "gui_speaker.models.UI_TARS", "UI_TARS_Agent"),
    ("GUI-R1", "gui_speaker.models.GUI_R1", "GUI_R1_Agent"),
    ("Agent-CPM", "gui_speaker.models.AgentCPM", "AgentCPM_GUI_Agent"),
    ("GUI-Owl", "gui_speaker.models.GUI_OWL", "GUI_OWL_Agent"),
)


def _resolve_agent_spec(model_name):
    """Return the ``(module_path, class_name)`` registered for *model_name*.

    Matching is a case-sensitive substring test against the keys of
    ``_AGENT_REGISTRY``, in registry order.

    Raises:
        ValueError: if *model_name* is None or contains no registered key.
    """
    if model_name is None:
        raise ValueError("modelName is required to select an agent")
    for key, module_path, class_name in _AGENT_REGISTRY:
        if key in model_name:
            return module_path, class_name
    known = ", ".join(key for key, _, _ in _AGENT_REGISTRY)
    raise ValueError(
        f"Unsupported model name {model_name!r}; expected it to contain one of: {known}"
    )


def init_agent(args, device, flag, modelName=None):
    """Create the GUI agent selected by *modelName* together with its Accelerator.

    Args:
        args: object exposing ``model_path``; it is forwarded to the agent
            constructor as ``policy_lm``.
        device: forwarded unchanged to the agent constructor as ``device``.
        flag: if truthy, the agent's model is loaded via ``agent._load_model()``
            and wrapped with ``accelerator.prepare``.
        modelName: name used to pick the agent class (see ``_AGENT_REGISTRY``).

    Returns:
        The constructed agent instance.

    Raises:
        ValueError: if *modelName* is None or matches no known agent. This is
            checked before any Accelerator / process-group setup happens.
    """
    import importlib

    # Validate first so a bad name fails fast, before distributed setup.
    module_path, class_name = _resolve_agent_spec(modelName)

    from accelerate import Accelerator
    from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs

    # DDP option: allow parameters that receive no gradient in a step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Both handlers must go through kwargs_handlers. Passing the process-group
    # kwargs positionally binds them to Accelerator's first parameter
    # (device_placement), which silently discards the 40-minute timeout.
    pg_kwargs = InitProcessGroupKwargs(timeout=timedelta(minutes=40))
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, pg_kwargs])

    agent_cls = getattr(importlib.import_module(module_path), class_name)
    agent = agent_cls(device=device,
                      accelerator=accelerator,
                      policy_lm=args.model_path)
    if flag:
        agent.model = agent._load_model()
        agent.model = accelerator.prepare(agent.model)
    return agent